import os 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.distributions.uniform import Uniform
import pickle
from collections import namedtuple
from itertools import chain, repeat
import numpy as np
from e2cnn import gspaces
from e2cnn import nn


class RCNNV1(torch.nn.Module):
    """Two equivariant conv blocks followed by one antialiased pooling stage.

    All layers are built on ``gspaces.Rot2dOnR2(N=8)``.  The input is
    declared as three trivial-representation fields; each conv block
    produces ``n_feats`` regular-representation fields.

    Args:
        n_feats: number of regular-representation fields output by each
            conv block.
    """

    def __init__(self, n_feats: int = 48):
        super().__init__()
        # Local import: at module level the name ``nn`` is bound more than
        # once, so the e2cnn package is bound explicitly here.
        from e2cnn import nn as enn

        self.r2_act = gspaces.Rot2dOnR2(N=8)
        group = self.r2_act

        # Type of the raw input: three trivial (scalar) fields.
        self.input_type = enn.FieldType(group, [group.trivial_repr] * 3)

        def conv_block(src_type):
            # R2Conv -> InnerBatchNorm -> ReLU, all sharing one output type.
            dst_type = enn.FieldType(group, [group.regular_repr] * n_feats)
            return enn.SequentialModule(
                enn.R2Conv(src_type, dst_type, kernel_size=5, padding=1, bias=False),
                enn.InnerBatchNorm(dst_type),
                enn.ReLU(dst_type, inplace=True),
            )

        # Construction order matches the order used in forward().
        self.block1 = conv_block(self.input_type)
        self.block2 = conv_block(self.block1.out_type)

        # Pooling is wrapped in a SequentialModule, so the pooling layer
        # itself lives at the submodule path ``pool1.0``.
        self.pool1 = enn.SequentialModule(
            enn.PointwiseAvgPoolAntialiased(
                self.block2.out_type, sigma=0.66, stride=2
            )
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run the network and return the result as a plain tensor.

        Args:
            input: tensor to be wrapped with ``self.input_type``
                (presumably shaped (batch, 3, H, W) to match the three
                trivial input fields -- confirm against callers).

        Returns:
            The ``.tensor`` attribute of the final GeometricTensor.
        """
        from e2cnn import nn as enn

        # Attach the input field type so the equivariant layers accept it.
        feats = enn.GeometricTensor(input, self.input_type)

        for stage in (self.block1, self.block2, self.pool1):
            feats = stage(feats)

        # Strip the GeometricTensor wrapper.
        return feats.tensor

class RCNNV1_2(torch.nn.Module):
    """Four equivariant conv blocks arranged as two conv-conv-pool stages.

    All layers are built on ``gspaces.Rot2dOnR2(N=8)``.  The input is
    declared as three trivial-representation fields; each conv block
    produces ``n_feats`` regular-representation fields.  Stage 1 uses
    5x5 kernels, stage 2 uses 3x3 kernels; both use ``padding=1``.

    Args:
        n_feats: number of regular-representation fields output by each
            conv block.
    """

    def __init__(self, n_feats: int = 48):
        super().__init__()
        # Local import: at module level the name ``nn`` is bound more than
        # once, so the e2cnn package is bound explicitly here.
        from e2cnn import nn as enn

        self.r2_act = gspaces.Rot2dOnR2(N=8)
        group = self.r2_act

        # Type of the raw input: three trivial (scalar) fields.
        self.input_type = enn.FieldType(group, [group.trivial_repr] * 3)

        def conv_block(src_type, kernel_size):
            # R2Conv -> InnerBatchNorm -> ReLU, all sharing one output type.
            dst_type = enn.FieldType(group, [group.regular_repr] * n_feats)
            return enn.SequentialModule(
                enn.R2Conv(
                    src_type, dst_type,
                    kernel_size=kernel_size, padding=1, bias=False,
                ),
                enn.InnerBatchNorm(dst_type),
                enn.ReLU(dst_type, inplace=True),
            )

        def avg_pool(field_type):
            # Antialiased average pooling with stride 2.
            return enn.PointwiseAvgPoolAntialiased(
                field_type, sigma=0.66, stride=2
            )

        # Stage 1: two 5x5 conv blocks, then pooling.  pool1 is wrapped in
        # a SequentialModule, so the pooling layer lives at ``pool1.0``.
        self.block1 = conv_block(self.input_type, 5)
        self.block2 = conv_block(self.block1.out_type, 5)
        self.pool1 = enn.SequentialModule(avg_pool(self.block2.out_type))

        # Stage 2: two 3x3 conv blocks, then pooling (pool2 is not wrapped).
        self.block3 = conv_block(self.block2.out_type, 3)
        self.block4 = conv_block(self.block3.out_type, 3)
        self.pool2 = avg_pool(self.block4.out_type)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run both stages and return the result as a plain tensor.

        Args:
            input: tensor to be wrapped with ``self.input_type``
                (presumably shaped (batch, 3, H, W) to match the three
                trivial input fields -- confirm against callers).

        Returns:
            The ``.tensor`` attribute of the final GeometricTensor.
        """
        from e2cnn import nn as enn

        # Attach the input field type so the equivariant layers accept it.
        feats = enn.GeometricTensor(input, self.input_type)

        pipeline = (
            self.block1, self.block2, self.pool1,
            self.block3, self.block4, self.pool2,
        )
        for stage in pipeline:
            feats = stage(feats)

        # Strip the GeometricTensor wrapper.
        return feats.tensor
    
class RCNNV1_3(torch.nn.Module):
    """Six equivariant conv blocks arranged as three conv-conv-pool stages.

    All layers are built on ``gspaces.Rot2dOnR2(N=8)``.  The input is
    declared as three trivial-representation fields; each conv block
    produces ``n_feats`` regular-representation fields.  Stage 1 uses
    5x5 kernels, stages 2 and 3 use 3x3 kernels; all use ``padding=1``.

    Submodule names (``block1`` .. ``block6``, ``pool1`` .. ``pool3``) and
    their construction order are kept stable so that module paths and the
    order in which layers are initialised do not change.

    Args:
        n_feats: number of regular-representation fields output by each
            conv block.
    """

    def __init__(self, n_feats: int = 48):
        super().__init__()
        # Local import: at module level the name ``nn`` is bound more than
        # once, so the e2cnn package is bound explicitly here.
        from e2cnn import nn as enn

        self.r2_act = gspaces.Rot2dOnR2(N=8)
        group = self.r2_act

        # Type of the raw input: three trivial (scalar) fields.
        self.input_type = enn.FieldType(group, 3 * [group.trivial_repr])

        def conv_block(src_type, kernel_size):
            # R2Conv -> InnerBatchNorm -> ReLU, all sharing one output type.
            dst_type = enn.FieldType(group, n_feats * [group.regular_repr])
            return enn.SequentialModule(
                enn.R2Conv(
                    src_type, dst_type,
                    kernel_size=kernel_size, padding=1, bias=False,
                ),
                enn.InnerBatchNorm(dst_type),
                enn.ReLU(dst_type, inplace=True),
            )

        def avg_pool(field_type):
            # Antialiased average pooling with stride 2.
            return enn.PointwiseAvgPoolAntialiased(
                field_type, sigma=0.66, stride=2
            )

        # Every layer below takes as input type the output type of the
        # layer that feeds it in forward(), so the wiring stays valid even
        # if individual block widths are changed later.

        # Stage 1: convolutions 1-2 (5x5), then pooling.  pool1 is wrapped
        # in a SequentialModule, so the pooling layer lives at ``pool1.0``.
        self.block1 = conv_block(self.input_type, 5)
        self.block2 = conv_block(self.block1.out_type, 5)
        self.pool1 = enn.SequentialModule(avg_pool(self.block2.out_type))

        # Stage 2: convolutions 3-4 (3x3), then pooling.
        self.block3 = conv_block(self.pool1.out_type, 3)
        self.block4 = conv_block(self.block3.out_type, 3)
        self.pool2 = avg_pool(self.block4.out_type)

        # Stage 3: convolutions 5-6 (3x3), then pooling.
        self.block5 = conv_block(self.pool2.out_type, 3)
        self.block6 = conv_block(self.block5.out_type, 3)
        self.pool3 = avg_pool(self.block6.out_type)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run all three stages and return the result as a plain tensor.

        Args:
            input: tensor to be wrapped with ``self.input_type``
                (presumably shaped (batch, 3, H, W) to match the three
                trivial input fields -- confirm against callers).

        Returns:
            The ``.tensor`` attribute of the final GeometricTensor.
        """
        from e2cnn import nn as enn

        # Attach the input field type so the equivariant layers accept it.
        x = enn.GeometricTensor(input, self.input_type)

        pipeline = (
            self.block1, self.block2, self.pool1,
            self.block3, self.block4, self.pool2,
            self.block5, self.block6, self.pool3,
        )
        for stage in pipeline:
            x = stage(x)

        # Strip the GeometricTensor wrapper.
        return x.tensor
    
class RCNNV1_4(torch.nn.Module):
    """Eight equivariant conv blocks arranged as four conv-conv-pool stages.

    All layers are built on ``gspaces.Rot2dOnR2(N=8)``.  The input is
    declared as three trivial-representation fields; each conv block
    produces ``n_feats`` regular-representation fields.  Stage 1 uses
    5x5 kernels, stages 2-4 use 3x3 kernels; all use ``padding=1``.
    Pools 1-3 use ``stride=2``; the final pool uses ``stride=1``.

    Submodule names (``block1`` .. ``block8``, ``pool1`` .. ``pool4``) and
    their construction order are kept stable so that module paths and the
    order in which layers are initialised do not change.

    Args:
        n_feats: number of regular-representation fields output by each
            conv block.
    """

    def __init__(self, n_feats: int = 48):
        super().__init__()
        # Local import: at module level the name ``nn`` is bound more than
        # once, so the e2cnn package is bound explicitly here.
        from e2cnn import nn as enn

        self.r2_act = gspaces.Rot2dOnR2(N=8)
        group = self.r2_act

        # Type of the raw input: three trivial (scalar) fields.
        self.input_type = enn.FieldType(group, 3 * [group.trivial_repr])

        def conv_block(src_type, kernel_size):
            # R2Conv -> InnerBatchNorm -> ReLU, all sharing one output type.
            dst_type = enn.FieldType(group, n_feats * [group.regular_repr])
            return enn.SequentialModule(
                enn.R2Conv(
                    src_type, dst_type,
                    kernel_size=kernel_size, padding=1, bias=False,
                ),
                enn.InnerBatchNorm(dst_type),
                enn.ReLU(dst_type, inplace=True),
            )

        def avg_pool(field_type, stride):
            # Antialiased average pooling with the given stride.
            return enn.PointwiseAvgPoolAntialiased(
                field_type, sigma=0.66, stride=stride
            )

        # Every layer below takes as input type the output type of the
        # layer that feeds it in forward(), so the wiring stays valid even
        # if individual block widths are changed later.

        # Stage 1: convolutions 1-2 (5x5), then pooling.  pool1 is wrapped
        # in a SequentialModule, so the pooling layer lives at ``pool1.0``.
        self.block1 = conv_block(self.input_type, 5)
        self.block2 = conv_block(self.block1.out_type, 5)
        self.pool1 = enn.SequentialModule(avg_pool(self.block2.out_type, 2))

        # Stage 2: convolutions 3-4 (3x3), then pooling.
        self.block3 = conv_block(self.pool1.out_type, 3)
        self.block4 = conv_block(self.block3.out_type, 3)
        self.pool2 = avg_pool(self.block4.out_type, 2)

        # Stage 3: convolutions 5-6 (3x3), then pooling.
        self.block5 = conv_block(self.pool2.out_type, 3)
        self.block6 = conv_block(self.block5.out_type, 3)
        self.pool3 = avg_pool(self.block6.out_type, 2)

        # Stage 4: convolutions 7-8 (3x3), then pooling.
        self.block7 = conv_block(self.pool3.out_type, 3)
        self.block8 = conv_block(self.block7.out_type, 3)
        # NOTE(review): stride=1 here, unlike the stride-2 pools above --
        # kept as in the original; confirm this is intentional.
        self.pool4 = avg_pool(self.block8.out_type, 1)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run all four stages and return the result as a plain tensor.

        Args:
            input: tensor to be wrapped with ``self.input_type``
                (presumably shaped (batch, 3, H, W) to match the three
                trivial input fields -- confirm against callers).

        Returns:
            The ``.tensor`` attribute of the final GeometricTensor.
        """
        from e2cnn import nn as enn

        # Attach the input field type so the equivariant layers accept it.
        x = enn.GeometricTensor(input, self.input_type)

        pipeline = (
            self.block1, self.block2, self.pool1,
            self.block3, self.block4, self.pool2,
            self.block5, self.block6, self.pool3,
            self.block7, self.block8, self.pool4,
        )
        for stage in pipeline:
            x = stage(x)

        # Strip the GeometricTensor wrapper.
        return x.tensor